# -*- coding: utf-8 -*-
import csv
import torch
import torch.utils.data as tud
from transformers import BertTokenizer
from torch.nn.utils.rnn import pad_sequence

TRAIN_DATA_PATH = '../data/train.tsv'
DEV_DATA_PATH = '../data/dev.tsv'
TOKENIZER_PATH = './bert-base-chinese'
MAX_LEN = 512
BATCH_SIZE = 32
# Prompt prepended to every review; the model fills the [MASK] slot with a
# verbalizer token ('很' -> positive, '不' -> negative), i.e. "很满意。" / "不满意。"
PREFIX = '[MASK]满意。'

def collate_fn(batch_data):
    """
    Collate function for the DataLoader: turns a list of instances into
    padded tensors.
    Args:
        batch_data: a list of instances produced by the Dataset
    Returns:
        dict of padded "input_ids", "attention_mask" and "labels" tensors
    """
    input_ids_list, attention_mask_list, labels_list = [], [], []
    for instance in batch_data:
        input_ids_temp = instance["input_ids"]
        attention_mask_temp = instance["mask"]
        labels_temp = instance["labels"]
        # Collect each field as a tensor; padding happens below
        input_ids_list.append(torch.tensor(input_ids_temp, dtype=torch.long))
        attention_mask_list.append(torch.tensor(attention_mask_temp, dtype=torch.long))
        labels_list.append(torch.tensor(labels_temp, dtype=torch.long))
    # pad_sequence pads every tensor to the longest length in the batch,
    # filling with padding_value; labels pad with -100 so padded positions
    # are ignored by the loss
    return {"input_ids": pad_sequence(input_ids_list, batch_first=True, padding_value=0),
            "attention_mask": pad_sequence(attention_mask_list, batch_first=True, padding_value=0),
            "labels": pad_sequence(labels_list, batch_first=True, padding_value=-100)}

class BinarySentiDataset(tud.Dataset):
    def __init__(self, data_path, tokenizer_path, max_len, prefix):
        super(BinarySentiDataset, self).__init__()
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
        self.max_len = max_len
        self.prefix = prefix
        # Verbalizer tokens that fill the [MASK] slot: '很' ("very") marks a
        # positive review, '不' ("not") a negative one
        self.pos_id = self.tokenizer.convert_tokens_to_ids('很')
        self.neg_id = self.tokenizer.convert_tokens_to_ids('不')
        # Look the mask id up instead of hard-coding it (103 for this vocab)
        self.mask_id = self.tokenizer.mask_token_id
        
        self.data_set = []
        with open(data_path, 'r', encoding='utf8') as rf:
            r = csv.reader(rf, delimiter='\t')
            next(r)  # skip the header row
            for row in r:
                # Prepend the prompt so each example reads "[MASK]满意。<review text>"
                text = self.prefix + row[2]
                # Let the tokenizer truncate, so the trailing [SEP] survives
                # instead of being sliced off
                input_ids = self.tokenizer.encode(text, max_length=self.max_len,
                                                  truncation=True)
                target = int(row[1])
                # Supervise only the [MASK] position; all other positions get
                # -100 and are ignored by the MLM cross-entropy loss
                if target == 0:
                    labels = [self.neg_id if idx == self.mask_id else -100 for idx in input_ids]
                else:
                    labels = [self.pos_id if idx == self.mask_id else -100 for idx in input_ids]
                mask = [1] * len(input_ids)
                self.data_set.append({"input_ids": input_ids, "mask": mask, "labels": labels})
               
    def __len__(self):
        return len(self.data_set)
    
    def __getitem__(self, idx):
        return self.data_set[idx]
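
# A minimal inference sketch: at each [MASK] position, compare the logits of
# the two verbalizer tokens to get a binary prediction. Assumes a masked-LM
# model (e.g. transformers' BertForMaskedLM) whose output has a .logits field.
@torch.no_grad()
def predict_batch(model, batch, mask_id, pos_id, neg_id):
    logits = model(input_ids=batch["input_ids"],
                   attention_mask=batch["attention_mask"]).logits
    # Boolean mask selecting the [MASK] position in each example
    mask_positions = batch["input_ids"] == mask_id
    mask_logits = logits[mask_positions]  # (num_masked, vocab_size)
    # 1 (positive) where '很' outscores '不', else 0 (negative)
    return (mask_logits[:, pos_id] > mask_logits[:, neg_id]).long()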
        
traindataset = BinarySentiDataset(TRAIN_DATA_PATH, TOKENIZER_PATH, MAX_LEN, PREFIX)
traindataloader = tud.DataLoader(traindataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

valdataset = BinarySentiDataset(DEV_DATA_PATH, TOKENIZER_PATH, MAX_LEN, PREFIX)
valdataloader = tud.DataLoader(valdataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
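
# A minimal training-step sketch, assuming transformers' BertForMaskedLM
# loaded from TOKENIZER_PATH; the positions labelled -100 above are ignored
# by the model's built-in MLM loss, so only the [MASK] slot is trained.
def train_one_epoch(model, dataloader, optimizer, device='cpu'):
    model.train()
    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(input_ids=batch["input_ids"],
                        attention_mask=batch["attention_mask"],
                        labels=batch["labels"])
        outputs.loss.backward()
        optimizer.step()
        optimizer.zero_grad()

# Usage sketch (model class and learning rate are illustrative):
# from transformers import BertForMaskedLM
# model = BertForMaskedLM.from_pretrained(TOKENIZER_PATH)
# optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
# train_one_epoch(model, traindataloader, optimizer)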